from modules.activation import get_activation
import torch
import torch.nn as nn


class Estimator(nn.Module):
    """Fully connected network mapping ``input_dim`` features to ``output_dim``.

    The network is a stack of ``nn.Linear`` layers. Every layer except the
    final one is followed by the activation returned by
    ``get_activation(activation)``; the final layer has no activation.

    Args:
        input_dim: Number of input features of the first linear layer.
        output_dim: Number of output features of the last linear layer.
        hidden_dims: Widths of the hidden layers, in order. An empty
            sequence produces a single ``Linear(input_dim, output_dim)``.
            (Original author's note: callers in this project pass
            ``[128, 64]`` rather than relying on the default.)
        activation: Activation identifier forwarded to ``get_activation``.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, input_dim,
                 output_dim,
                 hidden_dims=(256, 128, 64),
                 activation="elu",
                 **kwargs):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.estimator = self._build_mlp(
            self.input_dim,
            output_dim,
            hidden_dims,
            get_activation(activation),
        )

    @staticmethod
    def _build_mlp(input_dim, output_dim, hidden_dims, activation):
        """Assemble the Linear/activation stack as an ``nn.Sequential``."""
        widths = [input_dim, *hidden_dims]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            # The same activation object is reused at every position.
            # NOTE(review): if get_activation ever returns a module with
            # learnable parameters, those parameters are shared by all layers.
            layers.append(activation)
        # Output projection; deliberately not followed by an activation.
        layers.append(nn.Linear(widths[-1], output_dim))
        return nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the network and return its output."""
        return self.estimator(input)

    def inference(self, input):
        """Same computation as ``forward`` but inside ``torch.no_grad()``."""
        with torch.no_grad():
            return self.estimator(input)